Skip to content

[flang][OpenMP] Support for "atomic compare capture"#202315

Open
SunilKuravinakop wants to merge 4 commits into
llvm:mainfrom
SunilKuravinakop:compare_capture
Open

[flang][OpenMP] Support for "atomic compare capture"#202315
SunilKuravinakop wants to merge 4 commits into
llvm:mainfrom
SunilKuravinakop:compare_capture

Conversation

@SunilKuravinakop

Copy link
Copy Markdown
Contributor

Adding support for "!$omp atomic compare capture".

subroutine compare_capture_01(var1, num1, num2, num3)
integer :: var1, num1, num2, num3
!$omp atomic compare capture
num3 = var1
if (var1 == num1) var1 = num2
!$omp end atomic
end subroutine

This also Fixes #202311

@llvmorg-github-actions

llvmorg-github-actions Bot commented Jun 8, 2026

Copy link
Copy Markdown

@llvm/pr-subscribers-mlir-openmp
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-flang-openmp

@llvm/pr-subscribers-flang-fir-hlfir

Author: SunilKuravinakop

Changes

Adding support for "!$omp atomic compare capture".

subroutine compare_capture_01(var1, num1, num2, num3)
integer :: var1, num1, num2, num3
!$omp atomic compare capture
num3 = var1
if (var1 == num1) var1 = num2
!$omp end atomic
end subroutine

This also Fixes #202311


Patch is 28.93 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/202315.diff

10 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/Atomic.cpp (+51-39)
  • (modified) flang/test/Integration/OpenMP/atomic-compare.f90 (+32)
  • (modified) flang/test/Lower/OpenMP/atomic-compare.f90 (+56)
  • (modified) flang/test/Parser/OpenMP/atomic-unparse.f90 (+29)
  • (modified) mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td (+8-1)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+10)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+6)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+177-40)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+48)
  • (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+41)
diff --git a/flang/lib/Lower/OpenMP/Atomic.cpp b/flang/lib/Lower/OpenMP/Atomic.cpp
index 4ce0c8f878c48..c0f97aada637d 100644
--- a/flang/lib/Lower/OpenMP/Atomic.cpp
+++ b/flang/lib/Lower/OpenMP/Atomic.cpp
@@ -560,16 +560,45 @@ void Fortran::lower::omp::lowerAtomic(
   int action1 = analysis.op1.what & analysis.Action;
   memOrder = makeValidForAction(memOrder, action0, action1, version);
 
+  // --- Shared capture scaffolding ---
+  mlir::Operation *captureOp = nullptr;
+  fir::FirOpBuilder::InsertPoint preAt = builder.saveInsertionPoint();
+  fir::FirOpBuilder::InsertPoint atomicAt, postAt;
+
+  if (construct.IsCapture()) {
+    assert(action0 != analysis.None && action1 != analysis.None &&
+           "Expexcing two actions");
+    (void)action0;
+    (void)action1;
+    captureOp = mlir::omp::AtomicCaptureOp::create(
+        builder, loc, hint, makeMemOrderAttr(converter, memOrder));
+    // Set the non-atomic insertion point to before the atomic.capture.
+    preAt = getInsertionPointBefore(captureOp);
+
+    mlir::Block *block = builder.createBlock(&captureOp->getRegion(0));
+    builder.setInsertionPointToEnd(block);
+    // Set the atomic insertion point to before the terminator inside
+    // atomic.capture.
+    mlir::Operation *term = mlir::omp::TerminatorOp::create(builder, loc);
+    atomicAt = getInsertionPointBefore(term);
+    postAt = getInsertionPointAfter(captureOp);
+    hint = nullptr;
+    memOrder = std::nullopt;
+  }
+
   if (auto *cond = get(analysis.cond)) {
     // atomic compare: if (x == e) x = d
     // e : expecteVal
     // d : desiredVal
 
-    // Check for compound clauses (fail, capture) that are not yet
+    // Restore insertion point so pre-processing code (e.g. computing
+    // expectedVal) is emitted before the capture op, not after the terminator.
+    builder.restoreInsertionPoint(preAt);
+
+    // Check for compound clause (fail) that is not yet
     // supported with atomic compare.
     if (llvm::any_of(clauses, [](const omp::Clause &clause) {
-          return clause.id == llvm::omp::Clause::OMPC_fail ||
-                 clause.id == llvm::omp::Clause::OMPC_capture;
+          return clause.id == llvm::omp::Clause::OMPC_fail;
         })) {
       TODO(loc, "Compound clauses of OpenMP ATOMIC COMPARE");
     }
@@ -617,6 +646,17 @@ void Fortran::lower::omp::lowerAtomic(
       expectedVal = builder.createConvert(loc, elemTypeOfX, expectedVal);
     }
 
+    // If this is a compare+capture, generate the read op first.
+    if (construct.IsCapture()) {
+      assert(get(analysis.op0.assign) && (analysis.op0.what & analysis.Read) &&
+             "Expected a read operation for compare capture");
+      mlir::Operation *readOp = genAtomicRead(
+          converter, semaCtx, loc, stmtCtx, atomAddr, atom,
+          *get(analysis.op0.assign), hint, memOrder, preAt, atomicAt, postAt);
+      assert(readOp && "Should have created an atomic read operation");
+      builder.setInsertionPointAfter(readOp);
+    }
+
     mlir::UnitAttr weakAttr = nullptr;
     if (llvm::any_of(clauses, [](const omp::Clause &clause) {
           return clause.id == llvm::omp::Clause::OMPC_weak;
@@ -685,34 +725,9 @@ void Fortran::lower::omp::lowerAtomic(
     // Generate omp.yield
     mlir::omp::YieldOp::create(builder, loc, newVal);
     builder.setInsertionPointAfter(atomicOp);
-
     // END omp atomic compare
   } else {
-    mlir::Operation *captureOp = nullptr;
-    fir::FirOpBuilder::InsertPoint preAt = builder.saveInsertionPoint();
-    fir::FirOpBuilder::InsertPoint atomicAt, postAt;
-
-    if (construct.IsCapture()) {
-      // Capturing operation.
-      assert(action0 != analysis.None && action1 != analysis.None &&
-             "Expexcing two actions");
-      (void)action0;
-      (void)action1;
-      captureOp = mlir::omp::AtomicCaptureOp::create(
-          builder, loc, hint, makeMemOrderAttr(converter, memOrder));
-      // Set the non-atomic insertion point to before the atomic.capture.
-      preAt = getInsertionPointBefore(captureOp);
-
-      mlir::Block *block = builder.createBlock(&captureOp->getRegion(0));
-      builder.setInsertionPointToEnd(block);
-      // Set the atomic insertion point to before the terminator inside
-      // atomic.capture.
-      mlir::Operation *term = mlir::omp::TerminatorOp::create(builder, loc);
-      atomicAt = getInsertionPointBefore(term);
-      postAt = getInsertionPointAfter(captureOp);
-      hint = nullptr;
-      memOrder = std::nullopt;
-    } else {
+    if (!construct.IsCapture()) {
       // Non-capturing operation.
       assert(action0 != analysis.None && action1 == analysis.None &&
              "Expexcing single action");
@@ -735,16 +750,13 @@ void Fortran::lower::omp::lowerAtomic(
           *get(analysis.op1.assign), hint, memOrder, preAt, atomicAt, postAt);
     }
 
-    if (construct.IsCapture()) {
-      // If this is a capture operation, the first/second ops will be inside
-      // of it. Set the insertion point to past the capture op itself.
-      builder.restoreInsertionPoint(postAt);
-    } else {
-      if (secondOp) {
-        builder.setInsertionPointAfter(secondOp);
-      } else {
-        builder.setInsertionPointAfter(firstOp);
-      }
+    if (!construct.IsCapture()) {
+      builder.setInsertionPointAfter(secondOp ? secondOp : firstOp);
     }
   }
+
+  // Shared capture cleanup.
+  if (construct.IsCapture()) {
+    builder.restoreInsertionPoint(postAt);
+  }
 }
diff --git a/flang/test/Integration/OpenMP/atomic-compare.f90 b/flang/test/Integration/OpenMP/atomic-compare.f90
index 249fb0dd8fa64..650f64b80af12 100644
--- a/flang/test/Integration/OpenMP/atomic-compare.f90
+++ b/flang/test/Integration/OpenMP/atomic-compare.f90
@@ -260,3 +260,35 @@ subroutine atomic_compare_weak(x, e, d)
   if (x == e) x = d
 end 
 
+! Integer equality compare+capture: cmpxchg + store old value
+!CHECK-LABEL: define void @atomic_compare_capture_int_eq_(
+!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]], ptr noalias %[[D:.*]], ptr noalias %[[V:.*]])
+!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]]
+!CHECK: %[[DVAL:.*]] = load i32, ptr %[[D]]
+!CHECK: %[[RES:.*]] = cmpxchg ptr %[[X]], i32 %[[EVAL]], i32 %[[DVAL]] monotonic monotonic
+!CHECK: %[[OLD:.*]] = extractvalue { i32, i1 } %[[RES]], 0
+!CHECK: store i32 %[[OLD]], ptr %[[V]]
+subroutine atomic_compare_capture_int_eq(x, e, d, v)
+  integer :: x, e, d, v
+  !$omp atomic compare capture
+    v = x
+    if (x == e) x = d
+  !$omp end atomic
+end
+
+! Compare+capture with clause order reversed: capture compare
+!CHECK-LABEL: define void @atomic_capture_compare_int_eq_(
+!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]], ptr noalias %[[D:.*]], ptr noalias %[[V:.*]])
+!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]]
+!CHECK: %[[DVAL:.*]] = load i32, ptr %[[D]]
+!CHECK: %[[RES:.*]] = cmpxchg ptr %[[X]], i32 %[[EVAL]], i32 %[[DVAL]] monotonic monotonic
+!CHECK: %[[OLD:.*]] = extractvalue { i32, i1 } %[[RES]], 0
+!CHECK: store i32 %[[OLD]], ptr %[[V]]
+subroutine atomic_capture_compare_int_eq(x, e, d, v)
+  integer :: x, e, d, v
+  !$omp atomic capture compare
+    v = x
+    if (x == e) x = d
+  !$omp end atomic
+end
+
diff --git a/flang/test/Lower/OpenMP/atomic-compare.f90 b/flang/test/Lower/OpenMP/atomic-compare.f90
index ac70edbed4e60..752a221aa538d 100644
--- a/flang/test/Lower/OpenMP/atomic-compare.f90
+++ b/flang/test/Lower/OpenMP/atomic-compare.f90
@@ -161,3 +161,59 @@ subroutine atomic_compare_int_eq_weak(x, e, d)
   !$omp atomic compare weak
   if (x .eq. e) x = d
 end
+
+! CHECK-LABEL: func.func @_QPatomic_compare_capture_int_eq(
+! CHECK-SAME:    %[[X:.*]]: !fir.ref<i32> {fir.bindc_name = "x"},
+! CHECK-SAME:    %[[E:.*]]: !fir.ref<i32> {fir.bindc_name = "e"},
+! CHECK-SAME:    %[[D:.*]]: !fir.ref<i32> {fir.bindc_name = "d"},
+! CHECK-SAME:    %[[V:.*]]: !fir.ref<i32> {fir.bindc_name = "v"})
+! CHECK:         %[[D_DECL:.*]]:2 = hlfir.declare %[[D]] {{.*}}
+! CHECK:         %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
+! CHECK:         %[[V_DECL:.*]]:2 = hlfir.declare %[[V]] {{.*}}
+! CHECK:         %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
+! CHECK:         %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
+! CHECK:         omp.atomic.capture memory_order(relaxed) {
+! CHECK:           omp.atomic.read %[[V_DECL]]#0 = %[[X_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK:           omp.atomic.compare %[[X_DECL]]#0 : !fir.ref<i32> {
+! CHECK:           ^bb0(%[[XVAL:.*]]: i32):
+! CHECK:             %[[CMP:.*]] = arith.cmpi eq, %[[XVAL]], %[[EVAL]] : i32
+! CHECK:             %[[DVAL:.*]] = fir.load %[[D_DECL]]#0 : !fir.ref<i32>
+! CHECK:             %[[SEL:.*]] = arith.select %[[CMP]], %[[DVAL]], %[[XVAL]] : i32
+! CHECK:             omp.yield(%[[SEL]] : i32)
+! CHECK:           }
+! CHECK:         }
+subroutine atomic_compare_capture_int_eq(x, e, d, v)
+  integer :: x, e, d, v
+  !$omp atomic compare capture
+    v = x
+    if (x .eq. e) x = d
+  !$omp end atomic
+end
+
+! CHECK-LABEL: func.func @_QPatomic_compare_capture_int_gt(
+! CHECK-SAME:    %[[X:.*]]: !fir.ref<i32> {fir.bindc_name = "x"},
+! CHECK-SAME:    %[[E:.*]]: !fir.ref<i32> {fir.bindc_name = "e"},
+! CHECK-SAME:    %[[D:.*]]: !fir.ref<i32> {fir.bindc_name = "d"},
+! CHECK-SAME:    %[[V:.*]]: !fir.ref<i32> {fir.bindc_name = "v"})
+! CHECK:         %[[D_DECL:.*]]:2 = hlfir.declare %[[D]] {{.*}}
+! CHECK:         %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
+! CHECK:         %[[V_DECL:.*]]:2 = hlfir.declare %[[V]] {{.*}}
+! CHECK:         %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
+! CHECK:         %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
+! CHECK:         omp.atomic.capture memory_order(relaxed) {
+! CHECK:           omp.atomic.read %[[V_DECL]]#0 = %[[X_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK:           omp.atomic.compare %[[X_DECL]]#0 : !fir.ref<i32> {
+! CHECK:           ^bb0(%[[XVAL:.*]]: i32):
+! CHECK:             %[[CMP:.*]] = arith.cmpi sgt, %[[XVAL]], %[[EVAL]] : i32
+! CHECK:             %[[DVAL:.*]] = fir.load %[[D_DECL]]#0 : !fir.ref<i32>
+! CHECK:             %[[SEL:.*]] = arith.select %[[CMP]], %[[DVAL]], %[[XVAL]] : i32
+! CHECK:             omp.yield(%[[SEL]] : i32)
+! CHECK:           } {{.*}}weak{{.*}}
+! CHECK:         }
+subroutine atomic_compare_capture_int_gt(x, e, d, v)
+  integer :: x, e, d, v
+  !$omp atomic compare capture weak
+    v = x
+    if (x > e) x = d
+  !$omp end atomic
+end
diff --git a/flang/test/Parser/OpenMP/atomic-unparse.f90 b/flang/test/Parser/OpenMP/atomic-unparse.f90
index 4f3cf0eac0338..dc0cc1a62f6c2 100644
--- a/flang/test/Parser/OpenMP/atomic-unparse.f90
+++ b/flang/test/Parser/OpenMP/atomic-unparse.f90
@@ -192,6 +192,26 @@ program main
       i = j
    end if
 
+!COMPARE CAPTURE
+!$omp atomic compare capture
+   k = i
+   if (i .eq. j) then
+      i = k
+   end if
+!$omp end atomic
+!$omp atomic capture compare
+   k = i
+   if (i .eq. j) then
+      i = k
+   end if
+!$omp end atomic
+!$omp atomic capture compare weak
+   k = i
+   if (i < j) then
+      i = k
+   end if
+!$omp end atomic
+
 !ATOMIC
 !$omp atomic
    i = j
@@ -296,6 +316,15 @@ end program main
 !CHECK: !$OMP ATOMIC WEAK COMPARE
 !CHECK: !$OMP ATOMIC COMPARE SEQ_CST WEAK
 
+!COMPARE CAPTURE
+
+!CHECK: !$OMP ATOMIC COMPARE CAPTURE
+!CHECK: !$OMP END ATOMIC
+!CHECK: !$OMP ATOMIC CAPTURE COMPARE
+!CHECK: !$OMP END ATOMIC
+!CHECK: !$OMP ATOMIC CAPTURE COMPARE WEAK
+!CHECK: !$OMP END ATOMIC
+
 !ATOMIC
 !CHECK: !$OMP ATOMIC
 !CHECK: !$OMP ATOMIC SEQ_CST
diff --git a/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td b/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td
index abb21705b3c1c..8c9015f05bb72 100644
--- a/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td
@@ -289,10 +289,12 @@ def AtomicCaptureOpInterface : OpInterface<"AtomicCaptureOpInterface"> {
         auto secondReadStmt = dyn_cast<AtomicReadOpInterface>(secondOp);
         auto secondUpdateStmt = dyn_cast<AtomicUpdateOpInterface>(secondOp);
         auto secondWriteStmt = dyn_cast<AtomicWriteOpInterface>(secondOp);
+        auto secondCompareStmt = dyn_cast<AtomicCompareOpInterface>(secondOp);
 
         if (!((firstUpdateStmt && secondReadStmt) ||
               (firstReadStmt && secondUpdateStmt) ||
-              (firstReadStmt && secondWriteStmt)))
+              (firstReadStmt && secondWriteStmt) ||
+              (firstReadStmt && secondCompareStmt)))
           return ops.front().emitError()
                 << "invalid sequence of operations in the capture region";
         if (firstUpdateStmt && secondReadStmt &&
@@ -310,6 +312,11 @@ def AtomicCaptureOpInterface : OpInterface<"AtomicCaptureOpInterface"> {
           return firstReadStmt.emitError()
                 << "captured variable in atomic.read must be updated in "
                     "second operation";
+        if (firstReadStmt && secondCompareStmt &&
+            firstReadStmt.getX() != secondCompareStmt.getX())
+          return firstReadStmt.emitError()
+                << "captured variable in atomic.read must be updated in "
+                    "second operation";
 
         return mlir::success();
       }]
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 0962b330e2f23..1241abc10298f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1928,6 +1928,12 @@ def AtomicCaptureOp : OpenMP_Op<"atomic.capture", traits = [
         omp.atomic.write ...
         omp.terminator
       }
+
+      omp.atomic.capture {
+        omp.atomic.read ...
+        omp.atomic.compare ...
+        omp.terminator
+      }
     ```
   }] # clausesDescription;
 
@@ -1946,6 +1952,10 @@ def AtomicCaptureOp : OpenMP_Op<"atomic.capture", traits = [
     /// Returns the `atomic.update` operation inside the region, if any.
     /// Otherwise, it returns nullptr.
     AtomicUpdateOp getAtomicUpdateOp();
+
+    /// Returns the `atomic.compare` operation inside the region, if any.
+    /// Otherwise, it returns nullptr.
+    AtomicCompareOp getAtomicCompareOp();
   }] # clausesExtraClassDeclaration;
 
   let hasRegionVerifier = 1;
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index db5fd8f2e3230..0eafd0a267b97 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4615,6 +4615,12 @@ AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
   return dyn_cast<AtomicUpdateOp>(getSecondOp());
 }
 
+AtomicCompareOp AtomicCaptureOp::getAtomicCompareOp() {
+  if (auto op = dyn_cast<AtomicCompareOp>(getFirstOp()))
+    return op;
+  return dyn_cast<AtomicCompareOp>(getSecondOp());
+}
+
 LogicalResult AtomicCaptureOp::verify() {
   return verifySynchronizationHint(*this, getHint());
 }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 6f93ad231cfac..c5a07a7dc6cb2 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4795,6 +4795,43 @@ convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
   return success();
 }
 
+/// Helper to extract the OMPAtomicCompareOp from an integer comparison
+/// predicate. Returns std::nullopt for unsupported predicates.
+static std::optional<llvm::omp::OMPAtomicCompareOp>
+convertICmpPredicateToAtomicCompareOp(LLVM::ICmpPredicate predicate) {
+  switch (predicate) {
+  case LLVM::ICmpPredicate::eq:
+    return llvm::omp::OMPAtomicCompareOp::EQ;
+  case LLVM::ICmpPredicate::slt:
+  case LLVM::ICmpPredicate::ult:
+    return llvm::omp::OMPAtomicCompareOp::MIN;
+  case LLVM::ICmpPredicate::sgt:
+  case LLVM::ICmpPredicate::ugt:
+    return llvm::omp::OMPAtomicCompareOp::MAX;
+  default:
+    return std::nullopt;
+  }
+}
+
+/// Helper to extract the OMPAtomicCompareOp from a floating-point comparison
+/// predicate. Returns std::nullopt for unsupported predicates.
+static std::optional<llvm::omp::OMPAtomicCompareOp>
+convertFCmpPredicateToAtomicCompareOp(LLVM::FCmpPredicate predicate) {
+  switch (predicate) {
+  case LLVM::FCmpPredicate::oeq:
+  case LLVM::FCmpPredicate::ueq:
+    return llvm::omp::OMPAtomicCompareOp::EQ;
+  case LLVM::FCmpPredicate::olt:
+  case LLVM::FCmpPredicate::ult:
+    return llvm::omp::OMPAtomicCompareOp::MIN;
+  case LLVM::FCmpPredicate::ogt:
+  case LLVM::FCmpPredicate::ugt:
+    return llvm::omp::OMPAtomicCompareOp::MAX;
+  default:
+    return std::nullopt;
+  }
+}
+
 static LogicalResult
 convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
                         llvm::IRBuilderBase &builder,
@@ -4803,13 +4840,150 @@ convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
   if (failed(checkImplementationStatus(*atomicCaptureOp)))
     return failure();
 
+  omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
+  omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
+  omp::AtomicCompareOp atomicCompareOp = atomicCaptureOp.getAtomicCompareOp();
+
+  // If the capture contains an atomic.compare, delegate to
+  // createAtomicCompare with the capture variable (V) set.
+  if (atomicCompareOp) {
+    omp::AtomicReadOp atomicReadOp = atomicCaptureOp.getAtomicReadOp();
+    assert(atomicReadOp && "expected atomic.read in capture+compare");
+
+    Region &region = atomicCompareOp.getRegion();
+    Block &block = region.front();
+
+    llvm::Type *llvmXElementType =
+        moduleTranslation.convertType(block.getArgument(0).getType());
+    llvm::Value *llvmX = moduleTranslation.lookupValue(atomicCompareOp.getX());
+    llvm::Value *llvmV = moduleTranslation.lookupValue(atomicReadOp.getV());
+
+    bool isSigned = false;
+    llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {
+        llvmX, llvmXElementType, isSigned, /*IsVolatile=*/false};
+    llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {
+        llvmV, llvmXElementType, /*isSigned=*/false, /*IsVolatile=*/false};
+    llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicR = {nullptr, nullptr, false,
+                                                        false};
+
+    llvm::AtomicOrdering atomicOrdering =
+        convertAtomicOrdering(atomicCaptureOp.getMemoryOrder());
+
+    // Pre-translate non-pattern operations inside the compare region.
+    auto isAtomicComparePatternOp = [](Operation &op) {
+      return llvm::isa<LLVM::ICmpOp, LLVM::FCmpOp, LLVM::SelectOp, LLVM::AndOp,
+                       LLVM::OrOp>(op);
+    };
+    for (Operation &op : block.without_terminator()) {
+      if (isAtomicComparePatternOp(op))
+        continue;
+      bool allOperandsMapped =
+          llvm::all_of(op.getOperands(), [&](mlir::Value v) {
+            return moduleTranslation.lookupValue(v) != nullptr;
+          });
+      if (!allOperandsMapped)
+        continue;
+      if (failed(moduleTranslation.convertOperation(op, builder)))
+        return atomicCompareOp.emitError(
+            "failed to translate operation inside atomic compare region");
+    }
+
+    auto materializeValue = [&](mlir::Value val) -> llvm::Value * {
+      if (llvm::Value *existing = moduleTranslation.lookupValue(val))
+        return existing;
+      if (auto loadOp = val.getDefiningOp<LLVM::LoadOp>()) {
+        if (loadOp->getParentRegion() == &region) {
+          llvm::Value *loadAddr =
+              moduleTranslation.lookupValue(loadOp.getAddr());
+          if (!loadAddr)
+            return nullptr;
+          llvm::Type *loadType =
+              moduleTranslation.convertType(loadOp.getResult().getType());
+          return builder.CreateLoad(loadType, loadAddr);
+        }
+      }
+      return nullptr;
+    };
+
+    // Extract comparison predicate, eVal, and dVal from the region.
+    llvm::omp::OMPAtomicCompareOp compareOp = llvm::omp::OMPAtomicCompareOp::EQ;
+    llvm::Value *eVal = nullptr;
+    llvm::Value *dVal = nullptr;
+    bool isXBinopExpr = false;
+
+    for (Operation &op : block.getOperations()) {
+      if (auto icmpOp = dyn_cast<LLVM::ICmpOp>(op)) {
+        auto maybeOp =
+            convertICmpPredicateToAtomicCompareOp(icmpOp.getPredicate());
+        if (!maybeOp)
+     ...
[truncated]

@llvmorg-github-actions

Copy link
Copy Markdown

@llvm/pr-subscribers-mlir-openacc

Author: SunilKuravinakop

Changes

Adding support for "!$omp atomic compare capture".

subroutine compare_capture_01(var1, num1, num2, num3)
integer :: var1, num1, num2, num3
!$omp atomic compare capture
num3 = var1
if (var1 == num1) var1 = num2
!$omp end atomic
end subroutine

This also Fixes #202311


Patch is 28.93 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/202315.diff

10 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/Atomic.cpp (+51-39)
  • (modified) flang/test/Integration/OpenMP/atomic-compare.f90 (+32)
  • (modified) flang/test/Lower/OpenMP/atomic-compare.f90 (+56)
  • (modified) flang/test/Parser/OpenMP/atomic-unparse.f90 (+29)
  • (modified) mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td (+8-1)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+10)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+6)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+177-40)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+48)
  • (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+41)
diff --git a/flang/lib/Lower/OpenMP/Atomic.cpp b/flang/lib/Lower/OpenMP/Atomic.cpp
index 4ce0c8f878c48..c0f97aada637d 100644
--- a/flang/lib/Lower/OpenMP/Atomic.cpp
+++ b/flang/lib/Lower/OpenMP/Atomic.cpp
@@ -560,16 +560,45 @@ void Fortran::lower::omp::lowerAtomic(
   int action1 = analysis.op1.what & analysis.Action;
   memOrder = makeValidForAction(memOrder, action0, action1, version);
 
+  // --- Shared capture scaffolding ---
+  mlir::Operation *captureOp = nullptr;
+  fir::FirOpBuilder::InsertPoint preAt = builder.saveInsertionPoint();
+  fir::FirOpBuilder::InsertPoint atomicAt, postAt;
+
+  if (construct.IsCapture()) {
+    assert(action0 != analysis.None && action1 != analysis.None &&
+           "Expexcing two actions");
+    (void)action0;
+    (void)action1;
+    captureOp = mlir::omp::AtomicCaptureOp::create(
+        builder, loc, hint, makeMemOrderAttr(converter, memOrder));
+    // Set the non-atomic insertion point to before the atomic.capture.
+    preAt = getInsertionPointBefore(captureOp);
+
+    mlir::Block *block = builder.createBlock(&captureOp->getRegion(0));
+    builder.setInsertionPointToEnd(block);
+    // Set the atomic insertion point to before the terminator inside
+    // atomic.capture.
+    mlir::Operation *term = mlir::omp::TerminatorOp::create(builder, loc);
+    atomicAt = getInsertionPointBefore(term);
+    postAt = getInsertionPointAfter(captureOp);
+    hint = nullptr;
+    memOrder = std::nullopt;
+  }
+
   if (auto *cond = get(analysis.cond)) {
     // atomic compare: if (x == e) x = d
     // e : expecteVal
     // d : desiredVal
 
-    // Check for compound clauses (fail, capture) that are not yet
+    // Restore insertion point so pre-processing code (e.g. computing
+    // expectedVal) is emitted before the capture op, not after the terminator.
+    builder.restoreInsertionPoint(preAt);
+
+    // Check for compound clause (fail) that is not yet
     // supported with atomic compare.
     if (llvm::any_of(clauses, [](const omp::Clause &clause) {
-          return clause.id == llvm::omp::Clause::OMPC_fail ||
-                 clause.id == llvm::omp::Clause::OMPC_capture;
+          return clause.id == llvm::omp::Clause::OMPC_fail;
         })) {
       TODO(loc, "Compound clauses of OpenMP ATOMIC COMPARE");
     }
@@ -617,6 +646,17 @@ void Fortran::lower::omp::lowerAtomic(
       expectedVal = builder.createConvert(loc, elemTypeOfX, expectedVal);
     }
 
+    // If this is a compare+capture, generate the read op first.
+    if (construct.IsCapture()) {
+      assert(get(analysis.op0.assign) && (analysis.op0.what & analysis.Read) &&
+             "Expected a read operation for compare capture");
+      mlir::Operation *readOp = genAtomicRead(
+          converter, semaCtx, loc, stmtCtx, atomAddr, atom,
+          *get(analysis.op0.assign), hint, memOrder, preAt, atomicAt, postAt);
+      assert(readOp && "Should have created an atomic read operation");
+      builder.setInsertionPointAfter(readOp);
+    }
+
     mlir::UnitAttr weakAttr = nullptr;
     if (llvm::any_of(clauses, [](const omp::Clause &clause) {
           return clause.id == llvm::omp::Clause::OMPC_weak;
@@ -685,34 +725,9 @@ void Fortran::lower::omp::lowerAtomic(
     // Generate omp.yield
     mlir::omp::YieldOp::create(builder, loc, newVal);
     builder.setInsertionPointAfter(atomicOp);
-
     // END omp atomic compare
   } else {
-    mlir::Operation *captureOp = nullptr;
-    fir::FirOpBuilder::InsertPoint preAt = builder.saveInsertionPoint();
-    fir::FirOpBuilder::InsertPoint atomicAt, postAt;
-
-    if (construct.IsCapture()) {
-      // Capturing operation.
-      assert(action0 != analysis.None && action1 != analysis.None &&
-             "Expexcing two actions");
-      (void)action0;
-      (void)action1;
-      captureOp = mlir::omp::AtomicCaptureOp::create(
-          builder, loc, hint, makeMemOrderAttr(converter, memOrder));
-      // Set the non-atomic insertion point to before the atomic.capture.
-      preAt = getInsertionPointBefore(captureOp);
-
-      mlir::Block *block = builder.createBlock(&captureOp->getRegion(0));
-      builder.setInsertionPointToEnd(block);
-      // Set the atomic insertion point to before the terminator inside
-      // atomic.capture.
-      mlir::Operation *term = mlir::omp::TerminatorOp::create(builder, loc);
-      atomicAt = getInsertionPointBefore(term);
-      postAt = getInsertionPointAfter(captureOp);
-      hint = nullptr;
-      memOrder = std::nullopt;
-    } else {
+    if (!construct.IsCapture()) {
       // Non-capturing operation.
       assert(action0 != analysis.None && action1 == analysis.None &&
              "Expexcing single action");
@@ -735,16 +750,13 @@ void Fortran::lower::omp::lowerAtomic(
           *get(analysis.op1.assign), hint, memOrder, preAt, atomicAt, postAt);
     }
 
-    if (construct.IsCapture()) {
-      // If this is a capture operation, the first/second ops will be inside
-      // of it. Set the insertion point to past the capture op itself.
-      builder.restoreInsertionPoint(postAt);
-    } else {
-      if (secondOp) {
-        builder.setInsertionPointAfter(secondOp);
-      } else {
-        builder.setInsertionPointAfter(firstOp);
-      }
+    if (!construct.IsCapture()) {
+      builder.setInsertionPointAfter(secondOp ? secondOp : firstOp);
     }
   }
+
+  // Shared capture cleanup.
+  if (construct.IsCapture()) {
+    builder.restoreInsertionPoint(postAt);
+  }
 }
diff --git a/flang/test/Integration/OpenMP/atomic-compare.f90 b/flang/test/Integration/OpenMP/atomic-compare.f90
index 249fb0dd8fa64..650f64b80af12 100644
--- a/flang/test/Integration/OpenMP/atomic-compare.f90
+++ b/flang/test/Integration/OpenMP/atomic-compare.f90
@@ -260,3 +260,35 @@ subroutine atomic_compare_weak(x, e, d)
   if (x == e) x = d
 end 
 
+! Integer equality compare+capture: cmpxchg + store old value
+!CHECK-LABEL: define void @atomic_compare_capture_int_eq_(
+!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]], ptr noalias %[[D:.*]], ptr noalias %[[V:.*]])
+!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]]
+!CHECK: %[[DVAL:.*]] = load i32, ptr %[[D]]
+!CHECK: %[[RES:.*]] = cmpxchg ptr %[[X]], i32 %[[EVAL]], i32 %[[DVAL]] monotonic monotonic
+!CHECK: %[[OLD:.*]] = extractvalue { i32, i1 } %[[RES]], 0
+!CHECK: store i32 %[[OLD]], ptr %[[V]]
+subroutine atomic_compare_capture_int_eq(x, e, d, v)
+  integer :: x, e, d, v
+  !$omp atomic compare capture
+    v = x
+    if (x == e) x = d
+  !$omp end atomic
+end
+
+! Compare+capture with clause order reversed: capture compare
+!CHECK-LABEL: define void @atomic_capture_compare_int_eq_(
+!CHECK-SAME: ptr noalias %[[X:.*]], ptr noalias %[[E:.*]], ptr noalias %[[D:.*]], ptr noalias %[[V:.*]])
+!CHECK: %[[EVAL:.*]] = load i32, ptr %[[E]]
+!CHECK: %[[DVAL:.*]] = load i32, ptr %[[D]]
+!CHECK: %[[RES:.*]] = cmpxchg ptr %[[X]], i32 %[[EVAL]], i32 %[[DVAL]] monotonic monotonic
+!CHECK: %[[OLD:.*]] = extractvalue { i32, i1 } %[[RES]], 0
+!CHECK: store i32 %[[OLD]], ptr %[[V]]
+subroutine atomic_capture_compare_int_eq(x, e, d, v)
+  integer :: x, e, d, v
+  !$omp atomic capture compare
+    v = x
+    if (x == e) x = d
+  !$omp end atomic
+end
+
diff --git a/flang/test/Lower/OpenMP/atomic-compare.f90 b/flang/test/Lower/OpenMP/atomic-compare.f90
index ac70edbed4e60..752a221aa538d 100644
--- a/flang/test/Lower/OpenMP/atomic-compare.f90
+++ b/flang/test/Lower/OpenMP/atomic-compare.f90
@@ -161,3 +161,59 @@ subroutine atomic_compare_int_eq_weak(x, e, d)
   !$omp atomic compare weak
   if (x .eq. e) x = d
 end
+
+! CHECK-LABEL: func.func @_QPatomic_compare_capture_int_eq(
+! CHECK-SAME:    %[[X:.*]]: !fir.ref<i32> {fir.bindc_name = "x"},
+! CHECK-SAME:    %[[E:.*]]: !fir.ref<i32> {fir.bindc_name = "e"},
+! CHECK-SAME:    %[[D:.*]]: !fir.ref<i32> {fir.bindc_name = "d"},
+! CHECK-SAME:    %[[V:.*]]: !fir.ref<i32> {fir.bindc_name = "v"})
+! CHECK:         %[[D_DECL:.*]]:2 = hlfir.declare %[[D]] {{.*}}
+! CHECK:         %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
+! CHECK:         %[[V_DECL:.*]]:2 = hlfir.declare %[[V]] {{.*}}
+! CHECK:         %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
+! CHECK:         %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
+! CHECK:         omp.atomic.capture memory_order(relaxed) {
+! CHECK:           omp.atomic.read %[[V_DECL]]#0 = %[[X_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK:           omp.atomic.compare %[[X_DECL]]#0 : !fir.ref<i32> {
+! CHECK:           ^bb0(%[[XVAL:.*]]: i32):
+! CHECK:             %[[CMP:.*]] = arith.cmpi eq, %[[XVAL]], %[[EVAL]] : i32
+! CHECK:             %[[DVAL:.*]] = fir.load %[[D_DECL]]#0 : !fir.ref<i32>
+! CHECK:             %[[SEL:.*]] = arith.select %[[CMP]], %[[DVAL]], %[[XVAL]] : i32
+! CHECK:             omp.yield(%[[SEL]] : i32)
+! CHECK:           }
+! CHECK:         }
+subroutine atomic_compare_capture_int_eq(x, e, d, v)
+  integer :: x, e, d, v
+  !$omp atomic compare capture
+    v = x
+    if (x .eq. e) x = d
+  !$omp end atomic
+end
+
+! CHECK-LABEL: func.func @_QPatomic_compare_capture_int_gt(
+! CHECK-SAME:    %[[X:.*]]: !fir.ref<i32> {fir.bindc_name = "x"},
+! CHECK-SAME:    %[[E:.*]]: !fir.ref<i32> {fir.bindc_name = "e"},
+! CHECK-SAME:    %[[D:.*]]: !fir.ref<i32> {fir.bindc_name = "d"},
+! CHECK-SAME:    %[[V:.*]]: !fir.ref<i32> {fir.bindc_name = "v"})
+! CHECK:         %[[D_DECL:.*]]:2 = hlfir.declare %[[D]] {{.*}}
+! CHECK:         %[[E_DECL:.*]]:2 = hlfir.declare %[[E]] {{.*}}
+! CHECK:         %[[V_DECL:.*]]:2 = hlfir.declare %[[V]] {{.*}}
+! CHECK:         %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {{.*}}
+! CHECK:         %[[EVAL:.*]] = fir.load %[[E_DECL]]#0 : !fir.ref<i32>
+! CHECK:         omp.atomic.capture memory_order(relaxed) {
+! CHECK:           omp.atomic.read %[[V_DECL]]#0 = %[[X_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
+! CHECK:           omp.atomic.compare %[[X_DECL]]#0 : !fir.ref<i32> {
+! CHECK:           ^bb0(%[[XVAL:.*]]: i32):
+! CHECK:             %[[CMP:.*]] = arith.cmpi sgt, %[[XVAL]], %[[EVAL]] : i32
+! CHECK:             %[[DVAL:.*]] = fir.load %[[D_DECL]]#0 : !fir.ref<i32>
+! CHECK:             %[[SEL:.*]] = arith.select %[[CMP]], %[[DVAL]], %[[XVAL]] : i32
+! CHECK:             omp.yield(%[[SEL]] : i32)
+! CHECK:           } {{.*}}weak{{.*}}
+! CHECK:         }
+subroutine atomic_compare_capture_int_gt(x, e, d, v)
+  integer :: x, e, d, v
+  !$omp atomic compare capture weak
+    v = x
+    if (x > e) x = d
+  !$omp end atomic
+end
diff --git a/flang/test/Parser/OpenMP/atomic-unparse.f90 b/flang/test/Parser/OpenMP/atomic-unparse.f90
index 4f3cf0eac0338..dc0cc1a62f6c2 100644
--- a/flang/test/Parser/OpenMP/atomic-unparse.f90
+++ b/flang/test/Parser/OpenMP/atomic-unparse.f90
@@ -192,6 +192,26 @@ program main
       i = j
    end if
 
+!COMPARE CAPTURE
+!$omp atomic compare capture
+   k = i
+   if (i .eq. j) then
+      i = k
+   end if
+!$omp end atomic
+!$omp atomic capture compare
+   k = i
+   if (i .eq. j) then
+      i = k
+   end if
+!$omp end atomic
+!$omp atomic capture compare weak
+   k = i
+   if (i < j) then
+      i = k
+   end if
+!$omp end atomic
+
 !ATOMIC
 !$omp atomic
    i = j
@@ -296,6 +316,15 @@ end program main
 !CHECK: !$OMP ATOMIC WEAK COMPARE
 !CHECK: !$OMP ATOMIC COMPARE SEQ_CST WEAK
 
+!COMPARE CAPTURE
+
+!CHECK: !$OMP ATOMIC COMPARE CAPTURE
+!CHECK: !$OMP END ATOMIC
+!CHECK: !$OMP ATOMIC CAPTURE COMPARE
+!CHECK: !$OMP END ATOMIC
+!CHECK: !$OMP ATOMIC CAPTURE COMPARE WEAK
+!CHECK: !$OMP END ATOMIC
+
 !ATOMIC
 !CHECK: !$OMP ATOMIC
 !CHECK: !$OMP ATOMIC SEQ_CST
diff --git a/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td b/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td
index abb21705b3c1c..8c9015f05bb72 100644
--- a/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td
@@ -289,10 +289,12 @@ def AtomicCaptureOpInterface : OpInterface<"AtomicCaptureOpInterface"> {
         auto secondReadStmt = dyn_cast<AtomicReadOpInterface>(secondOp);
         auto secondUpdateStmt = dyn_cast<AtomicUpdateOpInterface>(secondOp);
         auto secondWriteStmt = dyn_cast<AtomicWriteOpInterface>(secondOp);
+        auto secondCompareStmt = dyn_cast<AtomicCompareOpInterface>(secondOp);
 
         if (!((firstUpdateStmt && secondReadStmt) ||
               (firstReadStmt && secondUpdateStmt) ||
-              (firstReadStmt && secondWriteStmt)))
+              (firstReadStmt && secondWriteStmt) ||
+              (firstReadStmt && secondCompareStmt)))
           return ops.front().emitError()
                 << "invalid sequence of operations in the capture region";
         if (firstUpdateStmt && secondReadStmt &&
@@ -310,6 +312,11 @@ def AtomicCaptureOpInterface : OpInterface<"AtomicCaptureOpInterface"> {
           return firstReadStmt.emitError()
                 << "captured variable in atomic.read must be updated in "
                     "second operation";
+        if (firstReadStmt && secondCompareStmt &&
+            firstReadStmt.getX() != secondCompareStmt.getX())
+          return firstReadStmt.emitError()
+                << "captured variable in atomic.read must be updated in "
+                    "second operation";
 
         return mlir::success();
       }]
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 0962b330e2f23..1241abc10298f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1928,6 +1928,12 @@ def AtomicCaptureOp : OpenMP_Op<"atomic.capture", traits = [
         omp.atomic.write ...
         omp.terminator
       }
+
+      omp.atomic.capture {
+        omp.atomic.read ...
+        omp.atomic.compare ...
+        omp.terminator
+      }
     ```
   }] # clausesDescription;
 
@@ -1946,6 +1952,10 @@ def AtomicCaptureOp : OpenMP_Op<"atomic.capture", traits = [
     /// Returns the `atomic.update` operation inside the region, if any.
     /// Otherwise, it returns nullptr.
     AtomicUpdateOp getAtomicUpdateOp();
+
+    /// Returns the `atomic.compare` operation inside the region, if any.
+    /// Otherwise, it returns nullptr.
+    AtomicCompareOp getAtomicCompareOp();
   }] # clausesExtraClassDeclaration;
 
   let hasRegionVerifier = 1;
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index db5fd8f2e3230..0eafd0a267b97 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4615,6 +4615,12 @@ AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
   return dyn_cast<AtomicUpdateOp>(getSecondOp());
 }
 
+AtomicCompareOp AtomicCaptureOp::getAtomicCompareOp() {
+  if (auto op = dyn_cast<AtomicCompareOp>(getFirstOp()))
+    return op;
+  return dyn_cast<AtomicCompareOp>(getSecondOp());
+}
+
 LogicalResult AtomicCaptureOp::verify() {
   return verifySynchronizationHint(*this, getHint());
 }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 6f93ad231cfac..c5a07a7dc6cb2 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4795,6 +4795,43 @@ convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
   return success();
 }
 
+/// Helper to extract the OMPAtomicCompareOp from an integer comparison
+/// predicate. Returns std::nullopt for unsupported predicates.
+static std::optional<llvm::omp::OMPAtomicCompareOp>
+convertICmpPredicateToAtomicCompareOp(LLVM::ICmpPredicate predicate) {
+  switch (predicate) {
+  case LLVM::ICmpPredicate::eq:
+    return llvm::omp::OMPAtomicCompareOp::EQ;
+  case LLVM::ICmpPredicate::slt:
+  case LLVM::ICmpPredicate::ult:
+    return llvm::omp::OMPAtomicCompareOp::MIN;
+  case LLVM::ICmpPredicate::sgt:
+  case LLVM::ICmpPredicate::ugt:
+    return llvm::omp::OMPAtomicCompareOp::MAX;
+  default:
+    return std::nullopt;
+  }
+}
+
+/// Helper to extract the OMPAtomicCompareOp from a floating-point comparison
+/// predicate. Returns std::nullopt for unsupported predicates.
+static std::optional<llvm::omp::OMPAtomicCompareOp>
+convertFCmpPredicateToAtomicCompareOp(LLVM::FCmpPredicate predicate) {
+  switch (predicate) {
+  case LLVM::FCmpPredicate::oeq:
+  case LLVM::FCmpPredicate::ueq:
+    return llvm::omp::OMPAtomicCompareOp::EQ;
+  case LLVM::FCmpPredicate::olt:
+  case LLVM::FCmpPredicate::ult:
+    return llvm::omp::OMPAtomicCompareOp::MIN;
+  case LLVM::FCmpPredicate::ogt:
+  case LLVM::FCmpPredicate::ugt:
+    return llvm::omp::OMPAtomicCompareOp::MAX;
+  default:
+    return std::nullopt;
+  }
+}
+
 static LogicalResult
 convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
                         llvm::IRBuilderBase &builder,
@@ -4803,13 +4840,150 @@ convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
   if (failed(checkImplementationStatus(*atomicCaptureOp)))
     return failure();
 
+  omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
+  omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
+  omp::AtomicCompareOp atomicCompareOp = atomicCaptureOp.getAtomicCompareOp();
+
+  // If the capture contains an atomic.compare, delegate to
+  // createAtomicCompare with the capture variable (V) set.
+  if (atomicCompareOp) {
+    omp::AtomicReadOp atomicReadOp = atomicCaptureOp.getAtomicReadOp();
+    assert(atomicReadOp && "expected atomic.read in capture+compare");
+
+    Region &region = atomicCompareOp.getRegion();
+    Block &block = region.front();
+
+    llvm::Type *llvmXElementType =
+        moduleTranslation.convertType(block.getArgument(0).getType());
+    llvm::Value *llvmX = moduleTranslation.lookupValue(atomicCompareOp.getX());
+    llvm::Value *llvmV = moduleTranslation.lookupValue(atomicReadOp.getV());
+
+    bool isSigned = false;
+    llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {
+        llvmX, llvmXElementType, isSigned, /*IsVolatile=*/false};
+    llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {
+        llvmV, llvmXElementType, /*isSigned=*/false, /*IsVolatile=*/false};
+    llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicR = {nullptr, nullptr, false,
+                                                        false};
+
+    llvm::AtomicOrdering atomicOrdering =
+        convertAtomicOrdering(atomicCaptureOp.getMemoryOrder());
+
+    // Pre-translate non-pattern operations inside the compare region.
+    auto isAtomicComparePatternOp = [](Operation &op) {
+      return llvm::isa<LLVM::ICmpOp, LLVM::FCmpOp, LLVM::SelectOp, LLVM::AndOp,
+                       LLVM::OrOp>(op);
+    };
+    for (Operation &op : block.without_terminator()) {
+      if (isAtomicComparePatternOp(op))
+        continue;
+      bool allOperandsMapped =
+          llvm::all_of(op.getOperands(), [&](mlir::Value v) {
+            return moduleTranslation.lookupValue(v) != nullptr;
+          });
+      if (!allOperandsMapped)
+        continue;
+      if (failed(moduleTranslation.convertOperation(op, builder)))
+        return atomicCompareOp.emitError(
+            "failed to translate operation inside atomic compare region");
+    }
+
+    auto materializeValue = [&](mlir::Value val) -> llvm::Value * {
+      if (llvm::Value *existing = moduleTranslation.lookupValue(val))
+        return existing;
+      if (auto loadOp = val.getDefiningOp<LLVM::LoadOp>()) {
+        if (loadOp->getParentRegion() == &region) {
+          llvm::Value *loadAddr =
+              moduleTranslation.lookupValue(loadOp.getAddr());
+          if (!loadAddr)
+            return nullptr;
+          llvm::Type *loadType =
+              moduleTranslation.convertType(loadOp.getResult().getType());
+          return builder.CreateLoad(loadType, loadAddr);
+        }
+      }
+      return nullptr;
+    };
+
+    // Extract comparison predicate, eVal, and dVal from the region.
+    llvm::omp::OMPAtomicCompareOp compareOp = llvm::omp::OMPAtomicCompareOp::EQ;
+    llvm::Value *eVal = nullptr;
+    llvm::Value *dVal = nullptr;
+    bool isXBinopExpr = false;
+
+    for (Operation &op : block.getOperations()) {
+      if (auto icmpOp = dyn_cast<LLVM::ICmpOp>(op)) {
+        auto maybeOp =
+            convertICmpPredicateToAtomicCompareOp(icmpOp.getPredicate());
+        if (!maybeOp)
+     ...
[truncated]

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds end-to-end support for OpenMP atomic compare capture in Flang/MLIR, enabling lowering and LLVM IR translation for constructs like !$omp atomic compare capture and addressing #202311.

Changes:

  • Extend omp.atomic.capture to allow an {atomic.read, atomic.compare} pairing (including {weak}).
  • Teach OpenMP-to-LLVM IR translation to lower compare+capture into cmpxchg + old-value capture store.
  • Add MLIR and Flang parser/lowering/integration tests covering compare+capture (and weak / clause-order variants).

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
mlir/test/Target/LLVMIR/openmp-llvm.mlir Adds LLVM IR FileCheck coverage for compare+capture (and weak) lowering.
mlir/test/Dialect/OpenMP/ops.mlir Adds dialect-level parsing/printing checks for omp.atomic.capture containing atomic.compare.
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp Implements translation support for atomic compare+capture via createAtomicCompare.
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp Adds AtomicCaptureOp::getAtomicCompareOp() helper.
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td Documents/declares atomic.compare usage inside atomic.capture and exposes accessor.
mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td Updates capture verifier to accept {read, compare} sequences.
flang/test/Parser/OpenMP/atomic-unparse.f90 Adds unparse coverage for atomic compare capture (including clause order/weak).
flang/test/Lower/OpenMP/atomic-compare.f90 Adds lowering checks emitting omp.atomic.capture { read; compare } for compare+capture.
flang/test/Integration/OpenMP/atomic-compare.f90 Adds integration IR checks validating cmpxchg + old-value store for compare+capture.
flang/lib/Lower/OpenMP/Atomic.cpp Lowers atomic compare capture by emitting an atomic.read followed by atomic.compare inside atomic.capture.

Comment on lines +4868 to +4871
llvm::Type *llvmXElementType =
moduleTranslation.convertType(block.getArgument(0).getType());
llvm::Value *llvmX = moduleTranslation.lookupValue(atomicCompareOp.getX());
llvm::Value *llvmV = moduleTranslation.lookupValue(atomicReadOp.getV());
Comment on lines +317 to +319
return firstReadStmt.emitError()
<< "captured variable in atomic.read must be updated in "
"second operation";
Comment on lines +569 to +570
assert(action0 != analysis.None && action1 != analysis.None &&
"Expexcing two actions");
Comment on lines 732 to 733
assert(action0 != analysis.None && action1 == analysis.None &&
"Expexcing single action");
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[flang][OpenMP] Support for "atomic compare capture"

2 participants